# Centre all knitr figures in the rendered document.
knitr::opts_chunk$set(fig.align="center")
# Packages for model fitting (brms), posterior draw extraction/summaries
# (tidybayes) and data wrangling/plotting (tidyverse, ggplot2, magrittr).
# NOTE(review): rstanarm, emmeans, modelr, bayesplot and gganimate are not
# called directly in the code shown here. Because brms is attached after
# rstanarm, brms's version of any function exported by both takes precedence.
library(rstanarm)
library(tidyverse)
library(tidybayes)
library(modelr)
library(ggplot2)
library(magrittr)
library(emmeans)
library(bayesplot)
library(brms)
library(gganimate)
# Default ggplot theme for any plot that does not add its own theme.
theme_set(theme_light())
In our experiment, we used a visualization recommendation algorithm (composed of one search algorithm and one oracle algorithm) to generate visualizations for the user on one of two datasets. We then measured the user’s time to complete each of four tasks: (1) Find Extremum, (2) Retrieve Value, (3) Prediction, and (4) Exploration.
Given a search algorithm (bfs or dfs), an oracle (compassql or dziban), and a dataset (birdstrikes or movies), we would like to predict the time it takes the average user to complete each task. In addition, we would like to know whether the choice of search algorithm and oracle has any meaningful impact on a user’s completion time for each of these four tasks.
# Read the completion times and convert the design variables to factors so
# that brms dummy-codes them in the model formula.
time_data <- read.csv("processed_completion_time.csv")

# log() of a zero or negative time gives -Inf/NaN, which the Gaussian model
# on log time cannot handle; fail early with a clear message instead.
if (any(time_data$completion_time <= 0, na.rm = TRUE)) {
  stop("completion_time must be positive to take its log", call. = FALSE)
}

time_data <- time_data %>%
  mutate(
    across(c(dataset, oracle, search, task), as.factor),
    # Model response, in log seconds.
    log_completion_time = log(completion_time)
  )

# Task labels exactly as they appear in the `task` column; used below to
# subset posterior draws one task at a time.
task_list <- c(
  "1. Find Extremum",
  "2. Retrieve Value",
  "3. Prediction",
  "4. Exploration"
)

# Seed shared by the model fit and the draw-generating calls below.
seed <- 12
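The model’s response and coefficients are in log seconds. We work in log space so that the model cannot predict completion times of less than zero seconds: exponentiating any log-scale value always gives a positive time. The intercept prior, normal(5.69, 0.69), is on the same log scale and centers the typical completion time near exp(5.69) ≈ 296 seconds.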
# Gaussian linear model of log completion time with every main effect and
# interaction of dataset, oracle, search algorithm and task.
# Priors: only the intercept gets an explicit prior, normal(5.69, 0.69) on
# the log scale (brms applies "Intercept" priors to the intercept of the
# centred design); population-level effects keep brms's default flat prior.
# `file` caches the fit: if "model_time.rds" already exists, brm() loads it
# instead of refitting, so `seed` only takes effect when the model is refit.
model <- brm(
  formula = bf(log_completion_time ~ dataset * oracle * search * task),
  prior = prior(normal(5.69, 0.69), class = "Intercept"),
  chains = 2,
  cores = 2,
  iter = 2500,
  warmup = 1000,
  data = time_data,
  seed = seed,
  file = "model_time"
)
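In the summary table, we want to see Rhat values close to 1.0 and Bulk_ESS values in the thousands. In the output below, every Rhat is 1.00, but Bulk_ESS for the population-level effects is in the high hundreds (586–899). This is adequate for posterior means but below that target; running more iterations would raise it.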
# Posterior summary: estimates, 95% intervals, Rhat and effective sample sizes.
model %>% summary()
## Family: gaussian
## Links: mu = identity; sigma = identity
## Formula: log_completion_time ~ dataset * oracle * search * task
## Data: time_data (Number of observations: 236)
## Samples: 2 chains, each with iter = 2500; warmup = 1000; thin = 1;
## total post-warmup samples = 3000
##
## Population-Level Effects:
## Estimate Est.Error
## Intercept 5.45 0.17
## datasetmovies -0.20 0.24
## oracledziban -0.01 0.24
## searchdfs 0.21 0.24
## task2.RetrieveValue -0.16 0.25
## task3.Prediction 0.83 0.24
## task4.Exploration 1.13 0.24
## datasetmovies:oracledziban 0.12 0.34
## datasetmovies:searchdfs -0.29 0.34
## oracledziban:searchdfs -0.39 0.34
## datasetmovies:task2.RetrieveValue 0.15 0.34
## datasetmovies:task3.Prediction 0.21 0.33
## datasetmovies:task4.Exploration 0.23 0.34
## oracledziban:task2.RetrieveValue 0.20 0.34
## oracledziban:task3.Prediction 0.01 0.34
## oracledziban:task4.Exploration -0.09 0.34
## searchdfs:task2.RetrieveValue -0.17 0.35
## searchdfs:task3.Prediction 0.16 0.33
## searchdfs:task4.Exploration -0.16 0.34
## datasetmovies:oracledziban:searchdfs 0.15 0.48
## datasetmovies:oracledziban:task2.RetrieveValue -0.16 0.48
## datasetmovies:oracledziban:task3.Prediction 0.12 0.47
## datasetmovies:oracledziban:task4.Exploration -0.03 0.47
## datasetmovies:searchdfs:task2.RetrieveValue 0.43 0.48
## datasetmovies:searchdfs:task3.Prediction 0.36 0.46
## datasetmovies:searchdfs:task4.Exploration 0.08 0.47
## oracledziban:searchdfs:task2.RetrieveValue 0.27 0.49
## oracledziban:searchdfs:task3.Prediction 0.11 0.48
## oracledziban:searchdfs:task4.Exploration 0.48 0.48
## datasetmovies:oracledziban:searchdfs:task2.RetrieveValue -0.33 0.68
## datasetmovies:oracledziban:searchdfs:task3.Prediction -0.42 0.66
## datasetmovies:oracledziban:searchdfs:task4.Exploration -0.12 0.67
## l-95% CI u-95% CI Rhat
## Intercept 5.11 5.78 1.00
## datasetmovies -0.68 0.27 1.00
## oracledziban -0.50 0.47 1.00
## searchdfs -0.26 0.68 1.00
## task2.RetrieveValue -0.62 0.33 1.00
## task3.Prediction 0.36 1.30 1.00
## task4.Exploration 0.66 1.61 1.00
## datasetmovies:oracledziban -0.55 0.79 1.00
## datasetmovies:searchdfs -0.96 0.36 1.00
## oracledziban:searchdfs -1.08 0.28 1.00
## datasetmovies:task2.RetrieveValue -0.53 0.81 1.00
## datasetmovies:task3.Prediction -0.46 0.86 1.00
## datasetmovies:task4.Exploration -0.44 0.90 1.00
## oracledziban:task2.RetrieveValue -0.47 0.88 1.00
## oracledziban:task3.Prediction -0.67 0.68 1.00
## oracledziban:task4.Exploration -0.76 0.57 1.00
## searchdfs:task2.RetrieveValue -0.86 0.49 1.00
## searchdfs:task3.Prediction -0.51 0.80 1.00
## searchdfs:task4.Exploration -0.83 0.52 1.00
## datasetmovies:oracledziban:searchdfs -0.77 1.07 1.00
## datasetmovies:oracledziban:task2.RetrieveValue -1.12 0.77 1.00
## datasetmovies:oracledziban:task3.Prediction -0.80 1.05 1.00
## datasetmovies:oracledziban:task4.Exploration -0.94 0.91 1.00
## datasetmovies:searchdfs:task2.RetrieveValue -0.49 1.40 1.00
## datasetmovies:searchdfs:task3.Prediction -0.55 1.29 1.00
## datasetmovies:searchdfs:task4.Exploration -0.82 1.00 1.00
## oracledziban:searchdfs:task2.RetrieveValue -0.69 1.21 1.00
## oracledziban:searchdfs:task3.Prediction -0.82 1.04 1.00
## oracledziban:searchdfs:task4.Exploration -0.46 1.39 1.00
## datasetmovies:oracledziban:searchdfs:task2.RetrieveValue -1.64 0.98 1.00
## datasetmovies:oracledziban:searchdfs:task3.Prediction -1.71 0.82 1.00
## datasetmovies:oracledziban:searchdfs:task4.Exploration -1.48 1.18 1.00
## Bulk_ESS Tail_ESS
## Intercept 691 1360
## datasetmovies 647 1197
## oracledziban 605 1314
## searchdfs 663 1218
## task2.RetrieveValue 681 1238
## task3.Prediction 798 1468
## task4.Exploration 885 1527
## datasetmovies:oracledziban 586 1233
## datasetmovies:searchdfs 693 1192
## oracledziban:searchdfs 635 1227
## datasetmovies:task2.RetrieveValue 704 1422
## datasetmovies:task3.Prediction 761 1399
## datasetmovies:task4.Exploration 848 1257
## oracledziban:task2.RetrieveValue 643 1165
## oracledziban:task3.Prediction 691 1550
## oracledziban:task4.Exploration 780 1610
## searchdfs:task2.RetrieveValue 658 1370
## searchdfs:task3.Prediction 820 1533
## searchdfs:task4.Exploration 894 1550
## datasetmovies:oracledziban:searchdfs 673 1354
## datasetmovies:oracledziban:task2.RetrieveValue 692 1301
## datasetmovies:oracledziban:task3.Prediction 693 1440
## datasetmovies:oracledziban:task4.Exploration 787 1478
## datasetmovies:searchdfs:task2.RetrieveValue 759 1533
## datasetmovies:searchdfs:task3.Prediction 829 1707
## datasetmovies:searchdfs:task4.Exploration 859 1217
## oracledziban:searchdfs:task2.RetrieveValue 644 1265
## oracledziban:searchdfs:task3.Prediction 767 1529
## oracledziban:searchdfs:task4.Exploration 830 1361
## datasetmovies:oracledziban:searchdfs:task2.RetrieveValue 774 1351
## datasetmovies:oracledziban:searchdfs:task3.Prediction 812 1602
## datasetmovies:oracledziban:searchdfs:task4.Exploration 899 1523
##
## Family Specific Parameters:
## Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## sigma 0.44 0.02 0.40 0.49 1.00 2237 2102
##
## Samples were drawn using sampling(NUTS). For each parameter, Bulk_ESS
## and Tail_ESS are effective sample size measures, and Rhat is the potential
## scale reduction factor on split chains (at convergence, Rhat = 1).
Trace plots help us check whether there is evidence of non-convergence for the model.
# Posterior densities and per-chain trace plots for the model's parameters.
model %>% plot()
In the pairs plots, we want to make sure no parameters are highly correlated: strong correlation between two parameters means the model has difficulty distinguishing their separate effects.
# Pairwise posterior scatterplots for the intercept and the main effects
# (interaction coefficients are not included). `fixed = TRUE` makes brms match
# the parameter names exactly instead of treating them as regular expressions.
pairs(
  model,
  pars = paste0("b_", c(
    "Intercept",
    "datasetmovies",
    "oracledziban",
    "searchdfs",
    "task2.RetrieveValue",
    "task3.Prediction",
    "task4.Exploration"
  )),
  fixed = TRUE
)
Using draws from the posterior, we can visualize the expected completion time in each condition. Be sure to apply an exponential transform to our log-transformed times to make them interpretable! The thicker, shorter line represents the 50% credible interval, while the thinner, longer line represents the 95% credible interval.
# Expected (fitted) log completion time for every row of time_data and every
# posterior draw, stored in `.value`. re_formula = NA: the model has no
# group-level terms to include.
# NOTE(review): add_fitted_draws() is superseded by add_epred_draws() in
# tidybayes >= 3.0 (which names the column `.epred` instead of `.value`).
draw_data <- add_fitted_draws(
  newdata = time_data,
  model = model,
  seed = seed,
  re_formula = NA
)
# For each task: plot the posterior of expected completion time (seconds) per
# condition, save the figure, and save mean / 95% / 50% intervals per
# search x oracle x dataset cell to CSV.
# NOTE(review): `condition` is not created in this script; it is assumed to
# be a column of processed_completion_time.csv — confirm.
for (task_name in task_list) {
  draw_data_sub <- filter(draw_data, task == task_name)

  task_plot <- ggplot(draw_data_sub, aes(
    x = exp(.value),
    y = condition,
    fill = dataset
  )) +
    # A fixed transparency, not a data mapping: inside aes() it would add a
    # spurious "0.5" legend and be rescaled by the alpha scale.
    stat_halfeye(.width = c(.95, .5), alpha = 0.5) +
    labs(x = "Time (Seconds)", y = "Condition")
  # Values are not auto-printed inside a for loop, so print explicitly.
  print(task_plot)

  # e.g. "1. Find Extremum" -> "time_Find_Extremum"
  filename <- paste0("time", gsub(" ", "_", gsub("^.*\\.", "", task_name)))
  ggsave(
    filename = paste0(filename, ".png"),
    plot = task_plot,
    path = "../plots/posterior_draws/time",
    width = 7,
    height = 5
  )

  fit_info <- draw_data_sub %>%
    group_by(search, oracle, dataset) %>%
    mean_qi(exp(.value), .width = c(.95, .5))
  print(fit_info)
  write.csv(
    fit_info,
    file.path("../plot_data/posterior_draws/time", paste0(filename, ".csv")),
    row.names = FALSE
  )
}
## Saving 7 x 5 in image
## Saving 7 x 5 in image
## Saving 7 x 5 in image
## Saving 7 x 5 in image
Next, we want to see whether there is any meaningful difference in completion time between the two search algorithms (bfs and dfs) and the two oracles (dziban and compassql).
# Posterior predictive draws (log seconds) for every observed row.
predictive_data <- time_data %>%
  add_predicted_draws(model, seed = seed, re_formula = NA)

# Per draw, average the predicted time (seconds) over the rows in each
# search x task x dataset cell, then take the between-search difference
# within each task x dataset. An unweighted mean is used; no per-row weights
# are involved.
diff_in_search_prediction <- predictive_data %>%
  group_by(search, task, dataset, .draw) %>%
  summarize(time = mean(exp(.prediction)), .groups = "drop_last") %>%
  compare_levels(time, by = search) %>%
  rename(diff_in_time = time)
## `summarise()` regrouping output by 'search', 'task', 'dataset' (override with `.groups` argument)
# Posterior distribution of the between-search time difference per task,
# with a dashed reference line at zero. The x-axis label names the
# comparison taken from the data (e.g. "dfs - bfs").
diff_in_search_prediction_plot <- diff_in_search_prediction %>%
  ggplot(aes(x = diff_in_time, y = task)) +
  xlab(paste0(
    "Time Difference (Seconds) (",
    as.character(diff_in_search_prediction$search[[1]]),
    ")"
  )) +
  ylab("Task") +
  stat_halfeye(.width = c(.95, .5)) +
  geom_vline(xintercept = 0, linetype = "longdash") +
  theme_minimal() +
  # Reverse the factor levels so the first task is drawn at the top.
  scale_y_discrete(limits = rev(levels(diff_in_search_prediction$task)))
ggsave(
  filename = "search_time_differences.png",
  plot = diff_in_search_prediction_plot,
  path = "../plots/comparisons/time",
  width = 7,
  height = 7
)
diff_in_search_prediction_plot
We can also check the numeric bounds of each credible interval to confirm whether it contains zero.
# Mean and 95% / 50% quantile intervals of the between-search difference
# per task, pooling the draws from both datasets.
diff_in_search_prediction_intervals <- diff_in_search_prediction %>%
  group_by(search, task) %>%
  mean_qi(diff_in_time, .width = c(.95, .5))
# write.csv() always writes comma-separated output and ignores `sep` (with a
# warning), so no `sep` is passed.
write.csv(
  diff_in_search_prediction_intervals,
  "../plot_data/comparisons/time/search_time_differences.csv",
  row.names = FALSE
)
diff_in_search_prediction_intervals
## # A tibble: 8 x 8
## # Groups: search [1]
## search task diff_in_time .lower .upper .width .point .interval
## <chr> <fct> <dbl> <dbl> <dbl> <dbl> <chr> <chr>
## 1 dfs - bfs 1. Find Extremum -17.2 -137. 115. 0.95 mean qi
## 2 dfs - bfs 2. Retrieve Value -0.0641 -115. 113. 0.95 mean qi
## 3 dfs - bfs 3. Prediction 142. -192. 493. 0.95 mean qi
## 4 dfs - bfs 4. Exploration -4.21 -404. 410. 0.95 mean qi
## 5 dfs - bfs 1. Find Extremum -17.2 -58.9 22.2 0.5 mean qi
## 6 dfs - bfs 2. Retrieve Value -0.0641 -37.0 36.6 0.5 mean qi
## 7 dfs - bfs 3. Prediction 142. 30.1 253. 0.5 mean qi
## 8 dfs - bfs 4. Exploration -4.21 -143. 134. 0.5 mean qi
Let’s repeat the comparison above, split by dataset.
# The between-search difference plot, with one panel per dataset and the
# densities filled by dataset.
diff_in_search_prediction_plot_split_by_dataset <-
  diff_in_search_prediction_plot +
  aes(fill = dataset) +
  facet_grid(cols = vars(dataset))
ggsave(
  filename = "split_by_dataset_search_time_differences.png",
  plot = diff_in_search_prediction_plot_split_by_dataset,
  path = "../plots/comparisons/time",
  width = 7,
  height = 7
)
diff_in_search_prediction_plot_split_by_dataset
Check the intervals for each dataset.
# Mean and 95% / 50% quantile intervals of the between-search difference,
# computed separately for each dataset and task.
diff_in_search_prediction_intervals_split_by_dataset <- mean_qi(
  group_by(diff_in_search_prediction, search, dataset, task),
  diff_in_time,
  .width = c(0.95, 0.5)
)
write.csv(
  diff_in_search_prediction_intervals_split_by_dataset,
  file = file.path(
    "../plot_data/comparisons/time",
    "search_time_differences_split_by_dataset.csv"
  ),
  row.names = FALSE
)
diff_in_search_prediction_intervals_split_by_dataset
## # A tibble: 16 x 9
## # Groups: search, dataset [2]
## search dataset task diff_in_time .lower .upper .width .point .interval
## <chr> <fct> <fct> <dbl> <dbl> <dbl> <dbl> <chr> <chr>
## 1 dfs - … birdstr… 1. Find… 8.02 -124. 135. 0.95 mean qi
## 2 dfs - … birdstr… 2. Retr… -4.12 -126. 113. 0.95 mean qi
## 3 dfs - … birdstr… 3. Pred… 165. -164. 507. 0.95 mean qi
## 4 dfs - … birdstr… 4. Expl… 72.6 -324. 484. 0.95 mean qi
## 5 dfs - … movies 1. Find… -42.4 -145. 52.0 0.95 mean qi
## 6 dfs - … movies 2. Retr… 3.99 -104. 112. 0.95 mean qi
## 7 dfs - … movies 3. Pred… 119. -222. 473. 0.95 mean qi
## 8 dfs - … movies 4. Expl… -81.0 -468. 280. 0.95 mean qi
## 9 dfs - … birdstr… 1. Find… 8.02 -36.0 51.8 0.5 mean qi
## 10 dfs - … birdstr… 2. Retr… -4.12 -43.4 34.7 0.5 mean qi
## 11 dfs - … birdstr… 3. Pred… 165. 55.5 271. 0.5 mean qi
## 12 dfs - … birdstr… 4. Expl… 72.6 -63.5 203. 0.5 mean qi
## 13 dfs - … movies 1. Find… -42.4 -73.2 -11.8 0.5 mean qi
## 14 dfs - … movies 2. Retr… 3.99 -30.1 38.9 0.5 mean qi
## 15 dfs - … movies 3. Pred… 119. 7.29 233. 0.5 mean qi
## 16 dfs - … movies 4. Expl… -81.0 -203. 42.6 0.5 mean qi
# Posterior predictive draws (log seconds) for every observed row. This is
# recomputed with the same model, data and seed so the section can run on
# its own.
predictive_data <- time_data %>%
  add_predicted_draws(model, seed = seed, re_formula = NA)

# Per draw, average the predicted time (seconds) over the rows in each
# oracle x task x dataset cell, then take the between-oracle difference
# within each task x dataset. An unweighted mean is used; no per-row weights
# are involved.
diff_in_oracle_prediction <- predictive_data %>%
  group_by(oracle, task, dataset, .draw) %>%
  summarize(time = mean(exp(.prediction)), .groups = "drop_last") %>%
  compare_levels(time, by = oracle) %>%
  rename(diff_in_time = time)
## `summarise()` regrouping output by 'oracle', 'task', 'dataset' (override with `.groups` argument)
# Posterior distribution of the between-oracle time difference per task,
# with a dashed reference line at zero. The x-axis label names the
# comparison taken from the data (e.g. "dziban - compassql").
diff_in_oracle_prediction_plot <- diff_in_oracle_prediction %>%
  ggplot(aes(x = diff_in_time, y = task)) +
  xlab(paste0(
    "Time Difference (Seconds) (",
    as.character(diff_in_oracle_prediction$oracle[[1]]),
    ")"
  )) +
  ylab("Task") +
  stat_halfeye(.width = c(.95, .5)) +
  geom_vline(xintercept = 0, linetype = "longdash") +
  theme_minimal() +
  # Reverse the factor levels so the first task is drawn at the top.
  scale_y_discrete(limits = rev(levels(diff_in_oracle_prediction$task)))
ggsave(
  filename = "oracle_time_differences.png",
  plot = diff_in_oracle_prediction_plot,
  path = "../plots/comparisons/time",
  width = 7,
  height = 7
)
diff_in_oracle_prediction_plot
We can also check the numeric bounds of each credible interval to confirm whether it contains zero.
# Mean and 95% / 50% quantile intervals of the between-oracle difference
# per task, pooling the draws from both datasets.
diff_in_oracle_prediction_intervals <- diff_in_oracle_prediction %>%
  group_by(oracle, task) %>%
  mean_qi(diff_in_time, .width = c(.95, .5))
# write.csv() always writes comma-separated output and ignores `sep` (with a
# warning), so no `sep` is passed.
write.csv(
  diff_in_oracle_prediction_intervals,
  "../plot_data/comparisons/time/oracle_time_differences.csv",
  row.names = FALSE
)
diff_in_oracle_prediction_intervals
## # A tibble: 8 x 8
## # Groups: oracle [1]
## oracle task diff_in_time .lower .upper .width .point .interval
## <chr> <fct> <dbl> <dbl> <dbl> <dbl> <chr> <chr>
## 1 dziban - comp… 1. Find Ext… -25.8 -157. 87.3 0.95 mean qi
## 2 dziban - comp… 2. Retrieve… 16.0 -95.4 137. 0.95 mean qi
## 3 dziban - comp… 3. Predicti… -78.2 -432. 268. 0.95 mean qi
## 4 dziban - comp… 4. Explorat… 3.61 -395. 402. 0.95 mean qi
## 5 dziban - comp… 1. Find Ext… -25.8 -65.8 17.1 0.5 mean qi
## 6 dziban - comp… 2. Retrieve… 16.0 -22.8 54.3 0.5 mean qi
## 7 dziban - comp… 3. Predicti… -78.2 -193. 39.7 0.5 mean qi
## 8 dziban - comp… 4. Explorat… 3.61 -127. 134. 0.5 mean qi
Let’s repeat the comparison above, split by dataset.
# The between-oracle difference plot, with one panel per dataset and the
# densities filled by dataset.
diff_in_oracle_prediction_plot_split_by_dataset <-
  diff_in_oracle_prediction_plot +
  aes(fill = dataset) +
  facet_grid(cols = vars(dataset))
ggsave(
  filename = "split_by_dataset_oracle_time_differences.png",
  plot = diff_in_oracle_prediction_plot_split_by_dataset,
  path = "../plots/comparisons/time",
  width = 7,
  height = 7
)
diff_in_oracle_prediction_plot_split_by_dataset
Check the intervals for each dataset.
# Mean and 95% / 50% quantile intervals of the between-oracle difference,
# computed separately for each dataset and task.
diff_in_oracle_prediction_intervals_split_by_dataset <- mean_qi(
  group_by(diff_in_oracle_prediction, oracle, dataset, task),
  diff_in_time,
  .width = c(0.95, 0.5)
)
write.csv(
  diff_in_oracle_prediction_intervals_split_by_dataset,
  file = file.path(
    "../plot_data/comparisons/time",
    "oracle_time_differences_split_by_dataset.csv"
  ),
  row.names = FALSE
)
diff_in_oracle_prediction_intervals_split_by_dataset
## # A tibble: 16 x 9
## # Groups: oracle, dataset [2]
## oracle dataset task diff_in_time .lower .upper .width .point .interval
## <chr> <fct> <fct> <dbl> <dbl> <dbl> <dbl> <chr> <chr>
## 1 dziban -… birdst… 1. Fin… -53.7 -182. 68.9 0.95 mean qi
## 2 dziban -… birdst… 2. Ret… 32.3 -80.3 152. 0.95 mean qi
## 3 dziban -… birdst… 3. Pre… -102. -440. 235. 0.95 mean qi
## 4 dziban -… birdst… 4. Exp… -40.6 -422. 371. 0.95 mean qi
## 5 dziban -… movies 1. Fin… 2.17 -89.9 96.7 0.95 mean qi
## 6 dziban -… movies 2. Ret… -0.333 -106. 105. 0.95 mean qi
## 7 dziban -… movies 3. Pre… -54.3 -414. 284. 0.95 mean qi
## 8 dziban -… movies 4. Exp… 47.8 -324. 419. 0.95 mean qi
## 9 dziban -… birdst… 1. Fin… -53.7 -94.1 -10.9 0.5 mean qi
## 10 dziban -… birdst… 2. Ret… 32.3 -7.10 70.1 0.5 mean qi
## 11 dziban -… birdst… 3. Pre… -102. -214. 5.33 0.5 mean qi
## 12 dziban -… birdst… 4. Exp… -40.6 -170. 86.7 0.5 mean qi
## 13 dziban -… movies 1. Fin… 2.17 -28.0 33.6 0.5 mean qi
## 14 dziban -… movies 2. Ret… -0.333 -36.0 32.8 0.5 mean qi
## 15 dziban -… movies 3. Pre… -54.3 -169. 65.6 0.5 mean qi
## 16 dziban -… movies 4. Exp… 47.8 -71.2 173. 0.5 mean qi
Finally, plot all of the posterior draws on one plot, colored by search algorithm and faceted by dataset.
# Expected completion time (seconds) per task for every posterior draw,
# filled by search algorithm, with one panel per dataset.
plot <- draw_data %>%
  ggplot(aes(x = exp(.value), y = task, fill = search)) +
  # A fixed transparency (not a data mapping), so the overlapping search
  # densities stay visible without adding a spurious alpha legend.
  stat_halfeye(.width = c(.95, .5), alpha = 0.5) +
  labs(x = "Time (Seconds)", y = "Task") +
  facet_grid(. ~ dataset)
plot
## Saving 7 x 5 in image